import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import os


def OWT(config=None):
    """Build the validation and test data loaders from pre-tokenized files.

    Token ids are read as ``uint16`` from ``<tokenizer.name>-val.bin`` and
    ``<tokenizer.name>-test.bin`` inside ``config.train.path``. Each file is
    cut into non-overlapping blocks of ``block_size`` tokens; the labels of a
    block are the same tokens shifted one position to the right.

    The validation file is sharded across processes using the ``RANK`` and
    ``WORLD_SIZE`` environment variables. The test file is never sharded:
    every process gets a loader over the complete file.

    Args:
        config: Configuration object. The attributes read here are
            ``block_size``, ``tokenizer.model_max_length``,
            ``tokenizer.name``, ``train.path``, ``num_workers`` and
            ``test.test_batch``.

    Returns:
        A ``(val_loader, test_loader)`` tuple of ``DataLoader`` objects that
        yield ``(input_ids, labels)`` batches of ``int64`` tensors.

    Raises:
        AssertionError: If the ``RANK`` environment variable is not set.
    """
    gpu_id = int(os.getenv("RANK", -1))
    if gpu_id == -1:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type is kept for backward compatibility.
        raise AssertionError("the RANK environment variable must be set")
    world_size = int(os.getenv("WORLD_SIZE", 1))
    block_size = min(config.block_size, config.tokenizer.model_max_length)
    val_file_path = os.path.join(
        config.train.path,
        config.tokenizer.name + "-val.bin",
    )
    test_file_path = os.path.join(
        config.train.path,
        config.tokenizer.name + "-test.bin",
    )
    val_file = np.memmap(val_file_path, dtype=np.uint16, mode="r")
    test_file = np.memmap(test_file_path, dtype=np.uint16, mode="r")

    def split_by_block_size_and_node(file, cur_gpu_id, cur_world_size):
        """Return a loader over the blocks of ``file`` owned by one rank.

        Blocks are assigned to ranks in contiguous runs of
        ``ceil(num_samples / cur_world_size)``; trailing ranks may receive
        fewer blocks, or none at all.
        """
        # One extra token is reserved so the last block still has a shifted
        # label sequence; an incomplete trailing block is dropped.
        num_samples = max((len(file) - 1) // block_size, 0)
        per_gpu_samples = (num_samples + cur_world_size - 1) // cur_world_size
        # Clamp both bounds so a rank past the end of the data gets an empty
        # shard instead of failing while stacking zero blocks.
        first_block = min(cur_gpu_id * per_gpu_samples, num_samples)
        last_block = min((cur_gpu_id + 1) * per_gpu_samples, num_samples)
        num_blocks = last_block - first_block
        st = first_block * block_size
        ed = last_block * block_size

        # Convert each contiguous token range once and reshape it, rather
        # than converting and stacking block by block. `astype` yields an
        # independent, writable int64 copy of the read-only memmap slice.
        input_ids = torch.from_numpy(
            np.asarray(file[st:ed]).astype(np.int64)
        ).view(num_blocks, block_size)
        labels = torch.from_numpy(
            np.asarray(file[st + 1 : ed + 1]).astype(np.int64)
        ).view(num_blocks, block_size)
        print(
            "rank: {}, input_id shape {}, labels shape {}".format(
                cur_gpu_id, input_ids.shape, labels.shape
            )
        )
        data_set = TensorDataset(input_ids, labels)
        data_loader = DataLoader(
            data_set,
            num_workers=config.num_workers,
            batch_size=config.test.test_batch,
            pin_memory=True,
        )
        return data_loader

    val_loader = split_by_block_size_and_node(val_file, gpu_id, world_size)
    # The test set is deliberately loaded in full (rank 0 of 1) on every
    # process.
    test_loader = split_by_block_size_and_node(test_file, 0, 1)
    return val_loader, test_loader
